import os
from os.path import join as pjoin
import common.paramUtil as paramUtil
from options.train_options import TrainOptions

from utils.plot_script import *

from networks.networks import *
from networks.trainer import Trainer
from data.dataset import MotionDataset
from scripts.motion_process import *
from torch.utils.data import DataLoader

def animation(data, save_dir):
    """Render every motion sample in ``data`` to ``<save_dir>/<NN>.mp4``.

    This callback relies on module-level names set in the ``__main__`` block:
    ``train_dataset`` (for ``inv_transform``), ``opt.joints_num``,
    ``kinematic_chain``, ``fps`` and ``radius``. It only works when this file
    is run as a script.

    Args:
        data: batch of model-space motion features, indexed per sample.
            Presumably a numpy array of shape (batch, frames, dim_pose);
            TODO confirm against the Trainer caller.
        save_dir: directory that receives one video per sample. It is created
            if it does not exist yet.
    """
    # Presumably undoes the mean/std normalisation the dataset applied.
    data = train_dataset.inv_transform(data)
    os.makedirs(save_dir, exist_ok=True)
    for i, joint_data in enumerate(data):
        # recover_from_ric presumably maps the feature vector back to
        # per-joint 3D positions; it operates on a float torch tensor.
        joint = recover_from_ric(torch.from_numpy(joint_data).float(),
                                 opt.joints_num).numpy()
        save_path = pjoin(save_dir, f"{i:02d}.mp4")
        plot_3d_motion(save_path, kinematic_chain, joint, title="None",
                       fps=fps, radius=radius)

def create_models(opt):
    """Instantiate the encoder, the generator and the optional discriminators.

    Channel layouts are read from module-level names (``e_mid_channels``,
    ``e_sp_channels``, ``e_st_channels``, ``g_channels``, ``d_channels``,
    ``dim_pose``), which the ``__main__`` block defines before this is called.

    Args:
        opt: parsed training options. This reads ``sp_use_in``, ``dim_style``,
            ``do_gan``, ``n_dis_down``, ``do_patch_gan`` and ``patch_size``.

    Returns:
        Tuple ``(encoder, generator, discriminator, patch_discriminator)``.
        ``discriminator`` is ``None`` unless ``opt.do_gan`` is set.
        ``patch_discriminator`` is ``None`` unless ``opt.do_patch_gan`` is set.
    """
    # Keep the construction order fixed (encoder, generator, discriminators)
    # in case module construction draws from the global RNG.
    enc = ResNetStyleContentEncoder(e_mid_channels, e_sp_channels,
                                    e_st_channels, opt.sp_use_in)
    gen = Generator(dim_pose, g_channels, opt.dim_style)
    dis = Discriminator(d_channels, opt.n_dis_down) if opt.do_gan else None
    patch_dis = (PatchDiscriminator(dim_pose, opt.patch_size)
                 if opt.do_patch_gan else None)
    return enc, gen, dis, patch_dis

def _report_params(label, model):
    """Print ``model`` and its total parameter count; return that count."""
    n_params = sum(param.numel() for param in model.parameters())
    print(model)
    print("Total parameters of {}: {}".format(label, n_params))
    return n_params


if __name__ == "__main__":
    # NOTE: everything below runs at module scope on purpose. `animation` and
    # `create_models` read these names (opt, train_dataset, dim_pose, the
    # channel lists, ...) as module globals.
    parser = TrainOptions()
    opt = parser.parse()

    opt.device = torch.device("cpu" if opt.gpu_id == -1 else "cuda:%d" % opt.gpu_id)
    # Anomaly detection makes autograd report the forward op behind a failing
    # backward pass. It noticeably slows training and is meant for debugging.
    torch.autograd.set_detect_anomaly(True)
    if opt.gpu_id != -1:
        torch.cuda.set_device(opt.gpu_id)

    # Validate the dataset before creating any output directories, so a bad
    # name does not leave empty run folders behind.
    if opt.dataset_name == 'aist':
        opt.data_root = '../data/aist'
        opt.data_dir = pjoin(opt.data_root, 'new_joint_vecs')
        opt.joints_num = 22
        dim_pose = 263
        radius = 4
        fps = 30
        kinematic_chain = paramUtil.t2m_kinematic_chain
    else:
        raise ValueError("Unsupported dataset: %r" % opt.dataset_name)

    # Run layout: <checkpoints_dir>/<dataset>/<name>/{model,meta,animation},
    # with logs under ./log/<dataset>/<name>.
    opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
    opt.model_dir = pjoin(opt.save_root, 'model')
    opt.meta_dir = pjoin(opt.save_root, "meta")
    opt.eval_dir = pjoin(opt.save_root, "animation")
    opt.log_dir = pjoin("./log", opt.dataset_name, opt.name)
    for out_dir in (opt.model_dir, opt.meta_dir, opt.eval_dir, opt.log_dir):
        os.makedirs(out_dir, exist_ok=True)

    # Encoder. The input width is dim_pose - 4; presumably 4 feature channels
    # are excluded from encoding. TODO confirm against ResNetStyleContentEncoder.
    e_mid_channels = [dim_pose - 4, 512, 768, 768]
    e_sp_channels = [768, 512, 256]
    e_st_channels = [768, opt.dim_style]
    # Generator: its first width matches the encoder's last e_sp_channels entry.
    g_channels = [e_sp_channels[-1], 768, 512, 384]
    # Discriminator
    d_channels = [dim_pose, 768, 512, 384, 256, 128]

    mean = np.load(pjoin(opt.data_root, "Mean.npy"))
    std = np.load(pjoin(opt.data_root, "Std.npy"))

    train_split_file = pjoin(opt.data_root, "train.txt")
    val_split_file = pjoin(opt.data_root, "val.txt")

    encoder, generator, discriminator, patch_discriminator = create_models(opt)
    all_params = _report_params("encoder net", encoder)
    all_params += _report_params("generator", generator)
    if opt.do_gan:
        all_params += _report_params("discriminator net", discriminator)
    if opt.do_patch_gan:
        all_params += _report_params("patch discriminator net", patch_discriminator)
    print('Total parameters of all models: {}'.format(all_params))

    trainer = Trainer(opt, encoder, generator, patch_discriminator, discriminator)
    train_dataset = MotionDataset(opt, mean, std, train_split_file)
    val_dataset = MotionDataset(opt, mean, std, val_split_file)
    train_loader = DataLoader(train_dataset, batch_size=opt.batch_size, num_workers=4,
                              drop_last=True, shuffle=True, pin_memory=True)
    # NOTE(review): with drop_last=True, a validation split smaller than
    # batch_size yields no batches at all.
    val_loader = DataLoader(val_dataset, batch_size=opt.batch_size, num_workers=4,
                            drop_last=True, shuffle=True, pin_memory=True)
    trainer.train(train_loader, val_loader, animation)